{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Key Value Database\n",
    "\n",
    "This is an interactive tutorial on an encrypted key-value database. The database supports three operations: **Insert, Replace, and Query**. All three are implemented as fully homomorphic encryption (FHE) circuits.\n",
    "\n",
    "You can find the full implementation in `frontends/concrete-python/examples/key_value_database/static_size.py`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below are the import statements.\n",
    "\n",
    "**time:** Used for measuring the time to create keys, encrypt and run circuits.\n",
    "\n",
    "**concrete.fhe:** Used for implementing homomorphic circuits.\n",
    "\n",
    "**numpy:** Used for mathematical operations. Concrete compiles the supported numpy operations into their FHE equivalents.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "from concrete import fhe\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below are the database configuration parameters. \n",
    "\n",
    "**Number of Entries:** Defines the maximum number of insertable (key, value) pairs. \n",
    "\n",
    "**Chunk Size:** Defines the size of each chunk in bits. Keys and values are split into chunks, which are the smallest units the circuits operate on.\n",
    "\n",
    "**Key Size:** Defines the size of each key in bits.\n",
    "\n",
    "**Value Size:** Defines the size of each value in bits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The number of entries in the database\n",
    "NUMBER_OF_ENTRIES = 5\n",
    "# The number of bits in each chunk\n",
    "CHUNK_SIZE = 4\n",
    "\n",
    "# The number of bits in the key and value\n",
    "KEY_SIZE = 32\n",
    "VALUE_SIZE = 32"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below is the definition of the state, together with the accessors/indexers into the state.\n",
    "\n",
    "The shape of the state follows from the key and value sizes. With a chunk size of 4 bits, some example layouts are given in the table below.\n",
    "\n",
    "| Flag Size | Key Size | Number of Key Chunks | Value Size | Number of Value Chunks |\n",
    "| --- | --- | --- | --- | --- |\n",
    "| 1         | 32       | 32/4 = 8                   | 32         | 32/4 = 8                      |\n",
    "| 1         | 8        | 8/4 = 2                    | 16          | 16/4 = 4                       |\n",
    "| 1         | 4        | 4/4 = 1                    | 4          | 4/4 = 1                       |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Key and Value size must be a multiple of chunk size\n",
    "assert KEY_SIZE % CHUNK_SIZE == 0\n",
    "assert VALUE_SIZE % CHUNK_SIZE == 0\n",
    "\n",
    "# Required number of chunks to store keys and values\n",
    "NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE\n",
    "NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE\n",
    "\n",
    "# The shape of the state as a tensor\n",
    "# Shape:\n",
    "# | Flag Size | Key Size | Value Size |\n",
    "# | 1         | 32/4 = 8 | 32/4 = 8   |\n",
    "STATE_SHAPE = (NUMBER_OF_ENTRIES, 1 + NUMBER_OF_KEY_CHUNKS + NUMBER_OF_VALUE_CHUNKS)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The slices below are used to index certain parts of the state. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Indexers for each part of the state\n",
    "FLAG = 0\n",
    "KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS)\n",
    "VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Encode/Decode Functions\n",
    "\n",
    "Encode/Decode functions are used to convert between integers and numpy arrays. The interface exposes integers, but the state is stored and processed as a numpy array."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Encode\n",
    "\n",
    "Encodes a number into a numpy array.\n",
    "\n",
    "- The number is written in binary, zero-padded to the given width in bits, and then split into chunks.\n",
    "- Each chunk is then converted to an integer.\n",
    "- The integers are then stored in a numpy array.\n",
    "\n",
    "The examples below assume `CHUNK_SIZE = 4`.\n",
    "\n",
    "| Function Call | Input (Integer) | Width (Bits) | Result (Numpy Array) |\n",
    "| --- | --- | --- | --- |\n",
    "| encode(25, 16) | 25 | 16 | [0, 0, 1, 9] |\n",
    "| encode(40, 16) | 40 | 16 | [0, 0, 2, 8] |\n",
    "| encode(11, 12) | 11 | 12 | [0, 0, 11] |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(number: int, width: int) -> np.ndarray:\n",
    "    # Zero-padded binary representation, `width` bits long\n",
    "    binary_repr = np.binary_repr(number, width=width)\n",
    "    # Split the bit string into chunks of CHUNK_SIZE bits\n",
    "    blocks = [binary_repr[i : i + CHUNK_SIZE] for i in range(0, len(binary_repr), CHUNK_SIZE)]\n",
    "    # Convert each chunk back to an integer\n",
    "    return np.array([int(block, 2) for block in blocks])\n",
    "\n",
    "\n",
    "# Encode a number with the key size\n",
    "def encode_key(number: int) -> np.ndarray:\n",
    "    return encode(number, width=KEY_SIZE)\n",
    "\n",
    "\n",
    "# Encode a number with the value size\n",
    "def encode_value(number: int) -> np.ndarray:\n",
    "    return encode(number, width=VALUE_SIZE)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Decode\n",
    "\n",
    "Decodes a numpy array into a number.\n",
    "\n",
    "| Function Call | Input(Numpy Array) | Result(Integer) |\n",
    "| --- | --- | --- |\n",
    "| decode([0, 0, 1, 9]) | [0, 0, 1, 9] | 25 |\n",
    "| decode([0, 0, 2, 8]) | [0, 0, 2, 8] | 40 |\n",
    "| decode([0, 0, 11]) | [0, 0, 11] | 11 |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode(encoded_number: np.ndarray) -> int:\n",
    "    result = 0\n",
    "    # Chunks are stored most significant first, so walk them from the end\n",
    "    for i in range(len(encoded_number)):\n",
    "        result += 2 ** (CHUNK_SIZE * i) * encoded_number[(len(encoded_number) - i) - 1]\n",
    "    return result"
   ]
  },
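  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check, the loop in `decode` can be read as a positional sum. The notation here is introduced only for illustration: $c_0, \\dots, c_{n-1}$ are the chunks, most significant first, and $s$ stands for `CHUNK_SIZE`.\n",
    "\n",
    "$$\\mathrm{decode}(c) = \\sum_{i=0}^{n-1} 2^{s \\cdot i} \\cdot c_{n-1-i}$$\n",
    "\n",
    "For example, with $s = 4$ and $c = [0, 0, 1, 9]$:\n",
    "\n",
    "$$2^{0} \\cdot 9 + 2^{4} \\cdot 1 + 2^{8} \\cdot 0 + 2^{12} \\cdot 0 = 9 + 16 = 25$$"
   ]
  },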
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Row Selection with Table Lookups\n",
    "\n",
    "The `keep_selected` function is used to select the correct row of the database for each operation.\n",
    "\n",
    "Below is the Python definition of the function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def keep_selected(value, selected):\n",
    "    if selected:\n",
    "        return value\n",
    "    return 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This function takes a value and a boolean flag that indicates whether the value is selected. We cannot compile this function as is for a homomorphic encryption circuit, because encrypted values cannot affect control flow. Instead, we turn the function into a lookup table.\n",
    "\n",
    "The `selected` bit is prepended to the value to form a packed index `i`, and the function then acts on `i` as follows (with `CHUNK_SIZE = 4`):\n",
    "\n",
    "- `i = 0..15` (`selected = 0`) `-> 0`\n",
    "- `i = 16..31` (`selected = 1`) `-> i - 16`\n",
    "\n",
    "Below is the Python code for the lookup table."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The first half (selected = 0) maps to 0, the second half (selected = 1) maps to the value\n",
    "keep_selected_lut = fhe.LookupTable([0] * 2**CHUNK_SIZE + list(range(2**CHUNK_SIZE)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The most significant bit of the input to the lookup table is the select bit. If `select = 0` (`i = 0..15`), the output is `0`. If `select = 1` (`i = 16..31`), the output is `i - 16`, the value itself.\n",
    "\n",
    "To summarize, we could implement the `keep_selected` function as below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def keep_selected_using_lut(value, selected):\n",
    "    packed = (2**CHUNK_SIZE) * selected + value\n",
    "    return keep_selected_lut[packed]"
   ]
  },
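  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the packing concrete, here is a worked example of the arithmetic in `keep_selected_using_lut`. It assumes `CHUNK_SIZE = 4`, and the value 9 is chosen purely for illustration.\n",
    "\n",
    "$$\\text{packed} = 2^{4} \\cdot \\text{selected} + \\text{value}$$\n",
    "\n",
    "- With $\\text{selected} = 1$: $\\text{packed} = 16 + 9 = 25$, and the table returns $25 - 16 = 9$, the value itself.\n",
    "- With $\\text{selected} = 0$: $\\text{packed} = 0 + 9 = 9$, and the table returns $0$."
   ]
  },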
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "### Circuit Implementation Functions\n",
    "\n",
    "The following functions are used to implement the key-value database circuits. \n",
    "Three circuits are implemented: \n",
    "- insert: Inserts a key value pair into the database\n",
    "- replace: Replaces the value of a key in the database\n",
    "- query: Queries the database for a key and returns the value\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Insert\n",
    "\n",
    "The algorithm of the insert function is as follows:\n",
    "- Create a selection array to select a certain row of the database\n",
    "- Fill this array by setting the entry of the first empty (unused) row of the database to 1\n",
    "- Create a state update array, where the first empty row of the database is set to the new key and value\n",
    "- Add the state update array to the state\n",
    "\n",
    "The implementation is below. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Insert a key value pair into the database\n",
    "# - state: The state of the database\n",
    "# - key: The key to insert\n",
    "# - value: The value to insert\n",
    "# Returns the updated state\n",
    "def _insert_impl(state, key, value):\n",
    "    # Get the used bit from the state\n",
    "    # This bit is used to determine if an entry is used or not\n",
    "    flags = state[:, FLAG]\n",
    "\n",
    "    # Create a selection array\n",
    "    # This array is used to select the first unused entry\n",
    "    selection = fhe.zeros(NUMBER_OF_ENTRIES)\n",
    "\n",
    "    # The found bit is used to determine if an unused entry has been found\n",
    "    found = fhe.zero()\n",
    "    for i in range(NUMBER_OF_ENTRIES):\n",
    "        # The packed flag and found bit are used to determine if the entry is unused\n",
    "        # | Flag | Found |\n",
    "        # | 0    | 0     | -> Unused, select\n",
    "        # | 0    | 1     | -> Unused, skip\n",
    "        # | 1    | 0     | -> Used, skip\n",
    "        # | 1    | 1     | -> Used, skip\n",
    "        packed_flag_and_found = (found * 2) + flags[i]\n",
    "        # Use the packed flag and found bit to determine if the entry is unused\n",
    "        is_selected = packed_flag_and_found == 0\n",
    "\n",
    "        # Update the selection array\n",
    "        selection[i] = is_selected\n",
    "        # Update the found bit, so all entries will be\n",
    "        # skipped after the first unused entry is found\n",
    "        found += is_selected\n",
    "\n",
    "    # Create a state update array\n",
    "    state_update = fhe.zeros(STATE_SHAPE)\n",
    "    # Update the state update array with the selection array\n",
    "    state_update[:, FLAG] = selection\n",
    "\n",
    "    # Reshape the selection array into a column so that it\n",
    "    # broadcasts against the key and value chunks\n",
    "    selection = selection.reshape((-1, 1))\n",
    "\n",
    "    # Create a packed selection and key array\n",
    "    # This array is used to update the key of the selected entry\n",
    "    packed_selection_and_key = (selection * (2**CHUNK_SIZE)) + key\n",
    "    key_update = keep_selected_lut[packed_selection_and_key]\n",
    "\n",
    "    # Create a packed selection and value array\n",
    "    # This array is used to update the value of the selected entry\n",
    "    packed_selection_and_value = selection * (2**CHUNK_SIZE) + value\n",
    "    value_update = keep_selected_lut[packed_selection_and_value]\n",
    "\n",
    "    # Update the state update array with the key and value update arrays\n",
    "    state_update[:, KEY] = key_update\n",
    "    state_update[:, VALUE] = value_update\n",
    "\n",
    "    # Update the state with the state update array\n",
    "    new_state = state + state_update\n",
    "    return new_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Replace\n",
    "\n",
    "The algorithm of the replace function is as follows:\n",
    "- Create an equal-rows array to select the rows whose key matches the given key\n",
    "- Create a selection array by combining the equal-rows array with the flags, so that only rows currently in use are selected\n",
    "- The selection array is 1 for the used row that contains the key, and 0 for all other rows\n",
    "- Create an inverse selection array by inverting the selection array\n",
    "- The row set to 1 in the selection array will be updated, whereas all other rows will stay the same\n",
    "- To do this, we use the lookup table to keep the new value where the selection array is 1, and the old values where the inverse selection array is 1\n",
    "- We then add the two arrays to get the new values of the state"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Replace the value of a key in the database\n",
    "#   If the key is not in the database, nothing happens\n",
    "#   If the key is in the database, the value is replaced\n",
    "# - state: The state of the database\n",
    "# - key: The key to replace\n",
    "# - value: The value to replace\n",
    "# Returns the updated state\n",
    "def _replace_impl(state, key, value):\n",
    "    # Get the flags, keys and values from the state\n",
    "    flags = state[:, FLAG]\n",
    "    keys = state[:, KEY]\n",
    "    values = state[:, VALUE]\n",
    "\n",
    "    # Create an equal_rows array\n",
    "    # This array is used to select all entries with the given key\n",
    "    # The equal_rows array is created by comparing the keys in the state\n",
    "    # with the given key, and only setting the entry to 1 if the keys are equal\n",
    "    # Example:\n",
    "    #   keys = [[1, 0, 1, 0], [0, 1, 0, 1]]\n",
    "    #   key = [1, 0, 1, 0]\n",
    "    #   equal_rows = [1, 0]\n",
    "    equal_rows = np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS\n",
    "\n",
    "    # Create a selection array\n",
    "    # This array is used to select the entry to change the value of\n",
    "    # The selection array is created by combining the equal_rows array\n",
    "    # with the flags array, which is used to determine if an entry is used or not\n",
    "    # The reason for combining the equal_rows array with the flags array\n",
    "    # is to make sure that only used entries are selected\n",
    "    selection = (flags * 2 + equal_rows == 3).reshape((-1, 1))\n",
    "\n",
    "    # Create a packed selection and value array\n",
    "    # This array is used to update the value of the selected entry\n",
    "    packed_selection_and_value = selection * (2**CHUNK_SIZE) + value\n",
    "    set_value = keep_selected_lut[packed_selection_and_value]\n",
    "\n",
    "    # Create an inverse selection array\n",
    "    # This array is used to pick entries that are not selected\n",
    "    # Example:\n",
    "    #   selection = [1, 0, 0]\n",
    "    #   inverse_selection = [0, 1, 1]\n",
    "    inverse_selection = 1 - selection\n",
    "\n",
    "    # Create a packed inverse selection and value array\n",
    "    # This array is used to keep the value of the entries that are not selected\n",
    "    packed_inverse_selection_and_values = inverse_selection * (2**CHUNK_SIZE) + values\n",
    "    kept_values = keep_selected_lut[packed_inverse_selection_and_values]\n",
    "\n",
    "    # Update the values of the state with the new values\n",
    "    new_values = kept_values + set_value\n",
    "    state[:, VALUE] = new_values\n",
    "\n",
    "    return state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Query\n",
    "\n",
    "The algorithm of the query function is as follows:\n",
    "- Create a selection array to select a certain row of the database\n",
    "- Set the selection array to 1 for the row that contains the key\n",
    "- Use the selection array with the lookup table to zero the values of all rows that do not contain the key\n",
    "- Sum the value rows to get the remaining non-zero row, effectively reducing a dimension\n",
    "- Prepend the found flag to the value and return the resulting array\n",
    "- The resulting array is destructured in the non-encrypted query function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the database for a key and return the value\n",
    "# - state: The state of the database\n",
    "# - key: The key to query\n",
    "# Returns an array with the following format:\n",
    "#   [found, value]\n",
    "#   found: 1 if the key was found, 0 otherwise\n",
    "#   value: The value of the key if the key was found, 0 otherwise\n",
    "def _query_impl(state, key):\n",
    "    # Get the keys and values from the state\n",
    "    keys = state[:, KEY]\n",
    "    values = state[:, VALUE]\n",
    "\n",
    "    # Create a selection array\n",
    "    # This array is used to select the entry with the given key\n",
    "    # The selection array is created by comparing the keys in the state\n",
    "    # with the given key, and only setting the entry to 1 if the keys are equal\n",
    "    # Example:\n",
    "    #   keys = [[1, 0, 1, 0], [0, 1, 0, 1]]\n",
    "    #   key = [1, 0, 1, 0]\n",
    "    #   selection = [1, 0]\n",
    "    selection = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1))\n",
    "\n",
    "    # Create a found bit\n",
    "    # This bit is used to determine if the key was found\n",
    "    # The found bit is set to 1 if the key was found, and 0 otherwise\n",
    "    found = np.sum(selection)\n",
    "\n",
    "    # Create a packed selection and value array\n",
    "    # This array is used to get the value of the selected entry\n",
    "    packed_selection_and_values = selection * (2**CHUNK_SIZE) + values\n",
    "    value_selection = keep_selected_lut[packed_selection_and_values]\n",
    "\n",
    "    # Sum the value selection array to get the value\n",
    "    value = np.sum(value_selection, axis=0)\n",
    "\n",
    "    # Return the found bit and the value\n",
    "    return fhe.array([found, *value])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Key-Value Database Interface\n",
    "\n",
    "The `KeyValueDatabase` class is the interface that exposes the functionality of the key-value database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "class KeyValueDatabase:\n",
    "    \"\"\"\n",
    "    A key-value database that uses fully homomorphic encryption circuits to store the data.\n",
    "    \"\"\"\n",
    "\n",
    "    # The state of the database, it holds all the\n",
    "    # keys and values as a table of entries\n",
    "    _state: np.ndarray\n",
    "\n",
    "    # The circuits used to implement the database\n",
    "    _insert_circuit: fhe.Circuit\n",
    "    _replace_circuit: fhe.Circuit\n",
    "    _query_circuit: fhe.Circuit\n",
    "\n",
    "    # Below is the initialization of the database.\n",
    "\n",
    "    # First, we initialize the state and provide the necessary input sets.\n",
    "    # Input sets are used to determine the required bit-width of the encrypted\n",
    "    # values. Hence, we add the largest possible value in the database to the\n",
    "    # input sets. In versions later than concrete-numpy 0.9.0, the\n",
    "    # `direct circuit` functionality can define these bit-widths directly,\n",
    "    # without input sets.\n",
    "\n",
    "    # Within the initialization phase, we create the required configuration,\n",
    "    # compilers, circuits, and keys. Circuit and key generation phase is\n",
    "    # timed and printed in the output.\n",
    "\n",
    "    def __init__(self):\n",
    "        # Initialize the state to all zeros\n",
    "        self._state = np.zeros(STATE_SHAPE, dtype=np.int64)\n",
    "\n",
    "        sample_state = np.array(\n",
    "            [\n",
    "                [i % 2] + encode_key(i).tolist() + encode_value(i).tolist()\n",
    "                for i in range(STATE_SHAPE[0])\n",
    "            ]\n",
    "        )\n",
    "\n",
    "        insert_replace_inputset = [\n",
    "            (\n",
    "                # state\n",
    "                sample_state,\n",
    "                # key\n",
    "                encode_key(i),\n",
    "                # value\n",
    "                encode_value(i),\n",
    "            )\n",
    "            for i in range(20)\n",
    "        ]\n",
    "        query_inputset = [\n",
    "            (\n",
    "                # state\n",
    "                sample_state,\n",
    "                # key\n",
    "                encode_key(i),\n",
    "            )\n",
    "            for i in range(20)\n",
    "        ]\n",
    "\n",
    "        ## Circuit compilation\n",
    "\n",
    "        # Create a configuration for the compiler\n",
    "        configuration = fhe.Configuration(\n",
    "            enable_unsafe_features=True,\n",
    "            use_insecure_key_cache=True,\n",
    "            insecure_key_cache_location=\".keys\",\n",
    "        )\n",
    "\n",
    "        # Create the compilers for the circuits\n",
    "        # Each compiler is provided with\n",
    "        # - The implementation of the circuit\n",
    "        # - The inputs and their corresponding types of the circuit\n",
    "        #  - \"encrypted\": The input is encrypted\n",
    "        #  - \"plain\": The input is not encrypted\n",
    "        insert_compiler = fhe.Compiler(\n",
    "            _insert_impl, {\"state\": \"encrypted\", \"key\": \"encrypted\", \"value\": \"encrypted\"}\n",
    "        )\n",
    "        replace_compiler = fhe.Compiler(\n",
    "            _replace_impl, {\"state\": \"encrypted\", \"key\": \"encrypted\", \"value\": \"encrypted\"}\n",
    "        )\n",
    "        query_compiler = fhe.Compiler(_query_impl, {\"state\": \"encrypted\", \"key\": \"encrypted\"})\n",
    "\n",
    "        ## Compile the circuits\n",
    "        # The circuits are compiled with the input set and the configuration\n",
    "\n",
    "        print()\n",
    "\n",
    "        print(\"Compiling insertion circuit...\")\n",
    "        start = time.time()\n",
    "        self._insert_circuit = insert_compiler.compile(insert_replace_inputset, configuration)\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "        print()\n",
    "\n",
    "        print(\"Compiling replacement circuit...\")\n",
    "        start = time.time()\n",
    "        self._replace_circuit = replace_compiler.compile(insert_replace_inputset, configuration)\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "        print()\n",
    "\n",
    "        print(\"Compiling query circuit...\")\n",
    "        start = time.time()\n",
    "        self._query_circuit = query_compiler.compile(query_inputset, configuration)\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "        print()\n",
    "\n",
    "        ## Generate the keys for the circuits\n",
    "        # The keys are generated separately for each circuit\n",
    "\n",
    "        print(\"Generating insertion keys...\")\n",
    "        start = time.time()\n",
    "        self._insert_circuit.keygen()\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "        print()\n",
    "\n",
    "        print(\"Generating replacement keys...\")\n",
    "        start = time.time()\n",
    "        self._replace_circuit.keygen()\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "        print()\n",
    "\n",
    "        print(\"Generating query keys...\")\n",
    "        start = time.time()\n",
    "        self._query_circuit.keygen()\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "    ### The Interface Functions\n",
    "\n",
    "    # The following methods are used to interact with the database.\n",
    "    # They are used to insert, replace and query the database.\n",
    "    # The methods are implemented by encrypting the inputs,\n",
    "    # running the circuit and decrypting the output.\n",
    "\n",
    "    # Insert a key-value pair into the database\n",
    "    # - key: The key to insert\n",
    "    # - value: The value to insert\n",
    "    # The key and value are encoded before they are inserted\n",
    "    # The state of the database is updated with the new key-value pair\n",
    "    def insert(self, key, value):\n",
    "        print()\n",
    "        print(\"Inserting...\")\n",
    "        start = time.time()\n",
    "        self._state = self._insert_circuit.encrypt_run_decrypt(\n",
    "            self._state, encode_key(key), encode_value(value)\n",
    "        )\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "    # Replace a key-value pair in the database\n",
    "    # - key: The key to replace\n",
    "    # - value: The new value to insert with the key\n",
    "    # The key and value are encoded before they are inserted\n",
    "    # The state of the database is updated with the new key-value pair\n",
    "    def replace(self, key, value):\n",
    "        print()\n",
    "        print(\"Replacing...\")\n",
    "        start = time.time()\n",
    "        self._state = self._replace_circuit.encrypt_run_decrypt(\n",
    "            self._state, encode_key(key), encode_value(value)\n",
    "        )\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "    # Query the database for a key\n",
    "    # - key: The key to query\n",
    "    # The key is encoded before it is queried\n",
    "    # Returns the value associated with the key or None if the key is not found\n",
    "    def query(self, key):\n",
    "        print()\n",
    "        print(\"Querying...\")\n",
    "        start = time.time()\n",
    "        result = self._query_circuit.encrypt_run_decrypt(self._state, encode_key(key))\n",
    "        end = time.time()\n",
    "        print(f\"(took {end - start:.3f} seconds)\")\n",
    "\n",
    "        if result[0] == 0:\n",
    "            return None\n",
    "\n",
    "        return decode(result[1:])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The implementation provided above is the statically-sized implementation. We briefly discuss the dynamic implementation below.\n",
    "\n",
    "Whereas the static implementation runs circuits over the whole database, the dynamic implementation runs circuits over a single row of the database.\n",
    "\n",
    "In the dynamic implementation, we iterate over the rows of the database in a simple Python loop and run the circuits on each row. This difference is reflected in the `insert`, `replace` and `query` functions.\n",
    "\n",
    "Comparing the two, the static implementation is more efficient with dense databases because it works on parallelized tensors, but it takes the same amount of time to query an empty database as one with 1 million entries. The dynamic implementation is more efficient with sparse databases because its cost grows with the number of entries, but it doesn't use circuit-level parallelization. In addition, insertion is free in the dynamic implementation, as it only appends a new item to a Python list."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have now finished the definition of the database. We can now use the database to insert, replace and query values."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Usage\n",
    "\n",
    "Below is the initialization of the database. As we provide parameters globally, we can simply initialize the database with the following command."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Compiling insertion circuit...\n",
      "(took 1.178 seconds)\n",
      "\n",
      "Compiling replacement circuit...\n",
      "(took 0.626 seconds)\n",
      "\n",
      "Compiling query circuit...\n",
      "(took 0.603 seconds)\n",
      "\n",
      "Generating insertion keys...\n",
      "(took 0.188 seconds)\n",
      "\n",
      "Generating replacement keys...\n",
      "(took 0.280 seconds)\n",
      "\n",
      "Generating query keys...\n",
      "(took 0.227 seconds)\n"
     ]
    }
   ],
   "source": [
    "## Test: Initialization\n",
    "# Initialize the database\n",
    "db = KeyValueDatabase()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can use the interface functions as provided below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Inserting...\n",
      "(took 0.768 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Test: Insert/Query\n",
    "# Insert (key: 3, value: 4) into the database\n",
    "db.insert(3, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Querying...\n",
      "(took 0.460 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Query the database for the key 3\n",
    "# The value 4 should be returned\n",
    "assert db.query(3) == 4"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Replacing...\n",
      "(took 0.806 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Test: Replace/Query\n",
    "# Replace the value of the key 3 with 1\n",
    "db.replace(3, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Querying...\n",
      "(took 0.483 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Query the database for the key 3\n",
    "# The value 1 should be returned\n",
    "assert db.query(3) == 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Inserting...\n",
      "(took 0.618 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Test: Insert/Query\n",
    "# Insert (key: 25, value: 40) into the database\n",
    "db.insert(25, 40)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Querying...\n",
      "(took 0.957 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Query the database for the key 25\n",
    "# The value 40 should be returned\n",
    "assert db.query(25) == 40"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Querying...\n",
      "(took 1.133 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Test: Query Not Found\n",
    "# Query the database for the key 4\n",
    "# None should be returned\n",
    "assert db.query(4) is None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Replacing...\n",
      "(took 2.325 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Test: Replace/Query\n",
    "# Replace the value of the key 3 with 5\n",
    "db.replace(3, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Querying...\n",
      "(took 1.172 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Query the database for the key 3\n",
    "# The value 5 should be returned\n",
    "assert db.query(3) == 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can now test the limits, we'll use the hyper-parameters `KEY_SIZE` and `VALUE_SIZE` in order to ensure that the examples work robustly against changes to the parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define lower/upper bounds for the key\n",
    "minimum_key = 1\n",
    "maximum_key = 2**KEY_SIZE - 1\n",
    "# Define lower/upper bounds for the value\n",
    "minimum_value = 1\n",
    "maximum_value = 2**VALUE_SIZE - 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Below are the usage examples with these bounds."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Inserting...\n",
      "(took 1.358 seconds)\n",
      "\n",
      "Querying...\n",
      "(took 1.137 seconds)\n",
      "\n",
      "Inserting...\n",
      "(took 1.383 seconds)\n",
      "\n",
      "Querying...\n",
      "(took 1.207 seconds)\n",
      "\n",
      "Replacing...\n",
      "(took 2.404 seconds)\n",
      "\n",
      "Querying...\n",
      "(took 1.241 seconds)\n",
      "\n",
      "Replacing...\n",
      "(took 2.345 seconds)\n",
      "\n",
      "Querying...\n",
      "(took 1.213 seconds)\n"
     ]
    }
   ],
   "source": [
    "# Test: Insert/Replace/Query Bounds\n",
    "# Insert (key: minimum_key , value: minimum_value) into the database\n",
    "db.insert(minimum_key, minimum_value)\n",
    "\n",
    "# Query the database for the key=minimum_key\n",
    "# The value minimum_value should be returned\n",
    "assert db.query(minimum_key) == minimum_value\n",
    "\n",
    "# Insert (key: maximum_key , value: maximum_value) into the database\n",
    "db.insert(maximum_key, maximum_value)\n",
    "\n",
    "# Query the database for the key=maximum_key\n",
    "# The value maximum_value should be returned\n",
    "assert db.query(maximum_key) == maximum_value\n",
    "\n",
    "# Replace the value of key=minimum_key with maximum_value\n",
    "db.replace(minimum_key, maximum_value)\n",
    "\n",
    "# Query the database for the key=minimum_key\n",
    "# The value maximum_value should be returned\n",
    "assert db.query(minimum_key) == maximum_value\n",
    "\n",
    "# Replace the value of key=maximum_key with minimum_value\n",
    "db.replace(maximum_key, minimum_value)\n",
    "\n",
    "# Query the database for the key=maximum_key\n",
    "# The value minimum_value should be returned\n",
    "assert db.query(maximum_key) == minimum_value"
   ]
  }
 ],
 "metadata": {
  "execution": {
   "timeout": 10800
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
